﻿#pragma kernel UpdateRopeMesh

#include "PathFrame.cginc"

// Per-renderer settings consumed by the UpdateRopeMesh kernel.
// NOTE(review): field order/types presumably mirror a CPU-side struct uploaded into
// `rendererData` - confirm before reordering or resizing anything here.
struct ropeMeshData
{
    uint axis;                // mesh axis (0 = x, 1 = y, anything else = z) that is swept along the path.
    float volumeScaling;      // weight of the squash factor applied to the cross-section when the rope is stretched/compressed.
    uint stretchWithRope;     // when 1, lengths along the axis are multiplied by smoothLength / restLength.
    uint spanEntireLength;    // when 1, the scale along `axis` is recomputed so all instances plus spacing cover restLength.
    uint instances;           // number of mesh copies laid out one after another along the path.
    float instanceSpacing;    // gap added between consecutive instances (scaled by axis scale and stretch ratio).
    float offset;             // initial length along the axis at which the first instance starts.
    float meshSizeAlongAxis;  // size of a single mesh instance along `axis`; only used when spanEntireLength is 1.
    float4 scale;             // per-axis mesh scale; only .xyz is read in this file.
};

// Per-path data stored in the `pathData` buffer.
// Only restLength and smoothLength are read in this file; the remaining fields are
// presumably consumed by other kernels and kept here to preserve the buffer layout (TODO confirm).
struct smootherPathData
{
    uint smoothing;              // not read in this file.
    float decimation;            // not read in this file.
    float twist;                 // not read in this file.
    float restLength;            // denominator of the stretch ratio; also the length to cover when spanning the entire rope.
    float smoothLength;          // numerator of the stretch ratio (smoothLength / restLength).
    uint usesOrientedParticles;  // not read in this file.
};

// Describes where a source mesh lives inside the shared vertex attribute buffers.
struct MeshData
{
    int firstVertex;    // index of the mesh's first vertex in positions/normals/tangents.
    int vertexCount;    // number of vertices in the mesh (iterated once per instance).

    int firstTriangle;  // not read in this file.
    int triangleCount;  // not read in this file.
};

// Per-renderer index of the path it follows (indexed by renderer index).
StructuredBuffer<int> pathSmootherIndices;
// Per-path index of its first chunk (indexed by path index).
StructuredBuffer<int> chunkOffsets;

// Path frames for all chunks, stored contiguously.
StructuredBuffer<pathFrame> frames;
// Per-chunk index of the first frame in `frames`.
StructuredBuffer<int> frameOffsets;
// Per-chunk number of frames.
StructuredBuffer<int> frameCounts;

// Per-renderer index of the first output vertex in `vertices`.
StructuredBuffer<int> vertexOffsets;

// Per-renderer index into `meshData`.
StructuredBuffer<int> meshIndices;
StructuredBuffer<MeshData> meshData;

// Mesh-local vertex indices in the order the kernel walks them
// (presumably sorted along the deform axis - TODO confirm against the CPU side).
StructuredBuffer<int> sortedIndices;
// Per-renderer index of the first entry in `sortedIndices`.
StructuredBuffer<int> sortedOffsets;

StructuredBuffer<ropeMeshData> rendererData;
StructuredBuffer<smootherPathData> pathData;

// Source mesh attributes, indexed by MeshData.firstVertex + mesh-local vertex index.
StructuredBuffer<float3> positions;
StructuredBuffer<float3> normals;
StructuredBuffer<float4> tangents;
// Not read in this file.
StructuredBuffer<float4> colors;

// Output vertex buffer. The kernel writes 14 32-bit values per vertex:
// position (3), normal (3), tangent (4), color (4).
RWByteAddressBuffer vertices;

// Variables set from the CPU
// Index of the first renderer handled by this dispatch.
uint firstRenderer;
// Number of renderers handled by this dispatch; threads past this count exit immediately.
uint rendererCount;

// Linearly blends two path frames: returns (1 - t) * a + t * b, after displacing
// b's position by bOffset. The normal and binormal of the result are renormalized.
//
// a, b    - frames to blend.
// bOffset - displacement added to b.position before blending. Callers pass a non-zero
//           value when b is a copy of the first/last frame of the path, so that meshes
//           keep extruding in a sensible direction before the first or past the last frame.
// t       - blend factor (0 yields a, 1 yields the displaced b).
pathFrame InterpolateFrames(pathFrame a, pathFrame b, float3 bOffset, float t)
{
    b.position = b.position + bOffset;

    pathFrame weightedA = multiplyFrame(1 - t, a);
    pathFrame weightedB = multiplyFrame(t, b);
    pathFrame blended = addFrames(weightedA, weightedB);

    // The tangent is intentionally left unnormalized: the caller zeroes the
    // component of the vertex offset along the deform axis, so it never scales anything.
    blended.normal = normalize(blended.normal);
    blended.binormal = normalize(blended.binormal);

    return blended;
}

// Deforms rope mesh renderers along their smoothed path, one renderer per thread.
//
// For every instance of the renderer's mesh, vertices are visited in the order given by
// `sortedIndices`. The distance advanced along the deform axis between consecutive vertices
// is accumulated and used to walk forwards/backwards through the path frames; each vertex
// is then expressed in the basis of the frame interpolated at its position along the path.
//
// Output layout in `vertices` (14 32-bit values per vertex):
//   [0..2] position, [3..5] normal, [6..9] tangent (xyz + w), [10..13] color.
[numthreads(16, 1, 1)]
void UpdateRopeMesh (uint3 id : SV_DispatchThreadID)
{
    unsigned int i = id.x;
    if (i >= rendererCount) return;

    int rendererIndex = firstRenderer + i;
    int pathIndex = pathSmootherIndices[rendererIndex];
    ropeMeshData renderer = rendererData[rendererIndex];

    // get mesh data:
    MeshData mesh = meshData[meshIndices[rendererIndex]];
    int sortedOffset = sortedOffsets[rendererIndex];

    // get index of first output vertex:
    int firstOutputVertex = vertexOffsets[rendererIndex];

    // get index of first chunk, ignore others (no support for tearing):
    int chunkIndex = chunkOffsets[pathIndex];

    // get first frame and frame count:
    int firstFrame = frameOffsets[chunkIndex];
    int lastFrame = firstFrame + frameCounts[chunkIndex] - 1;

    // get mesh deform axis:
    int axis = renderer.axis;

    // initialize scale vector:
    float3 actualScale = renderer.scale.xyz;

    // calculate stretch ratio:
    float stretchRatio = renderer.stretchWithRope == 1 ? pathData[chunkIndex].smoothLength / pathData[chunkIndex].restLength : 1;

    // squashing factor, makes mesh thinner when stretched and thicker when compressed.
    float squashing = clamp(1 + renderer.volumeScaling * (1 / max(stretchRatio, 0.01f) - 1), 0.01f, 2);

    // calculate scale along swept axis so that the mesh spans the entire length of the rope if required.
    if (renderer.spanEntireLength == 1)
    {
        float totalMeshLength = renderer.meshSizeAlongAxis * renderer.instances;
        float totalSpacing = renderer.instanceSpacing * (renderer.instances - 1);
        float axisScale = pathData[chunkIndex].restLength / (totalMeshLength + totalSpacing);

        if (axis == 0) actualScale.x = axisScale;
        else if (axis == 1) actualScale.y = axisScale;
        else actualScale.z = axisScale;
    }

    // init loop variables:
    // NOTE(review): nextIndex is not clamped to lastFrame here, so a chunk with fewer than
    // two frames would read a frame outside the chunk - confirm such chunks never reach this kernel.
    float lengthAlongAxis = renderer.offset;
    int index = firstFrame;
    int nextIndex = firstFrame + 1;
    int prevIndex = firstFrame;
    float nextMagnitude = distance(frames[index].position, frames[nextIndex].position);
    float prevMagnitude = nextMagnitude;

    // (A former pre-pass that stored undeformed, half-width positions into the output buffer
    // was removed: everything it wrote was overwritten by the first instance below, and for
    // renderers with zero instances it left bogus data in the output buffer.)

    for (int k = 0; k < (int)renderer.instances; ++k)
    {
        for (int j = 0; j < mesh.vertexCount; ++j)
        {
            // mesh-local index of the vertex being processed, shared by input and output addressing:
            int sortedIndex = sortedIndices[sortedOffset + j];

            int currVIndex = mesh.firstVertex + sortedIndex;
            int prevVIndex = mesh.firstVertex + sortedIndices[sortedOffset + max(0,j - 1)];

            float3 restPosition = positions[currVIndex];

            // calculate how much we've advanced in the sort axis since the last vertex
            // (zero for the first vertex of each instance, since prevVIndex == currVIndex):
            lengthAlongAxis += (restPosition[axis] - positions[prevVIndex][axis]) * actualScale[axis] * stretchRatio;

            // check if we have moved to a new section of the curve:
            // NOTE(review): both branches divide by the segment length; coincident frames
            // would make it zero - confirm the path smoother never emits them.
            pathFrame frame;
            if (lengthAlongAxis < 0)
            {
                // walk backwards until the remaining (negative) length fits in the previous segment:
                while (-lengthAlongAxis > prevMagnitude && index > firstFrame)
                {
                    lengthAlongAxis += prevMagnitude;
                    index = max(index - 1, firstFrame);
                    nextIndex = min(index + 1, lastFrame);
                    prevIndex = max(index - 1, firstFrame);
                    nextMagnitude = distance(frames[index].position, frames[nextIndex].position);
                    prevMagnitude = distance(frames[index].position, frames[prevIndex].position);
                }

                // at the very first frame there is no previous frame: extrapolate using a copy
                // of it displaced opposite to the direction of the next frame.
                float3 offset = float3(0,0,0);
                if (index == prevIndex)
                {
                    offset = frames[index].position - frames[nextIndex].position;
                    prevMagnitude = length(offset);
                }

                frame = InterpolateFrames(frames[index], frames[prevIndex], offset, -lengthAlongAxis / prevMagnitude);
            }
            else
            {
                // walk forwards until the remaining length fits in the next segment:
                while (lengthAlongAxis > nextMagnitude && index < lastFrame)
                {
                    lengthAlongAxis -= nextMagnitude;
                    index = min(index + 1, lastFrame);
                    nextIndex = min(index + 1, lastFrame);
                    prevIndex = max(index - 1, firstFrame);
                    nextMagnitude = distance(frames[index].position, frames[nextIndex].position);
                    prevMagnitude = distance(frames[index].position, frames[prevIndex].position);
                }

                // at the very last frame there is no next frame: extrapolate using a copy
                // of it displaced opposite to the direction of the previous frame.
                float3 offset = float3(0,0,0);
                if (index == nextIndex)
                {
                    offset = frames[index].position - frames[prevIndex].position;
                    nextMagnitude = length(offset);
                }

                frame = InterpolateFrames(frames[index], frames[nextIndex], offset, lengthAlongAxis / nextMagnitude);
            }

            // update basis matrix:
            float3x3 basis = frame.ToMatrix(axis);

            // calculate vertex offset from curve (the component along the deform axis is
            // already accounted for by the position along the path, so it is zeroed):
            float3 offsetFromCurve = restPosition * actualScale * frame.thickness * squashing;
            if (axis == 0) offsetFromCurve.x = 0;
            else if (axis == 1) offsetFromCurve.y = 0;
            else offsetFromCurve.z = 0;

            // write modified vertex data (byte address = 32-bit element index << 2):
            // NOTE(review): color comes from frames[index], not from the interpolated frame - confirm intended.
            int base = (firstOutputVertex + sortedIndex) * 14;
            vertices.Store3( base<<2, asuint(frame.position + mul(basis, offsetFromCurve)));
            vertices.Store3((base + 3)<<2, asuint(mul(basis, normals[currVIndex])));
            vertices.Store4((base + 6)<<2, asuint(float4(mul(basis, tangents[currVIndex].xyz), tangents[currVIndex].w)));
            vertices.Store4((base + 10)<<2, asuint(frames[index].color));
        }

        // the next instance writes right after this one, and starts further along the path:
        firstOutputVertex += mesh.vertexCount;
        lengthAlongAxis += renderer.instanceSpacing * actualScale[axis] * stretchRatio;
    }
}